### PGD implementation

import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F
# import torchvision
from torch.autograd import Variable
import torch.optim as optim
# import logging
import wandb
from tqdm import tqdm


def attack_pgd(args, model, vl_dl, attack_params):
    """Run a PGD attack over a validation loader and log clean/adversarial accuracy.

    Each batch is moved to the GPU and attacked with ``pgd``. Clean accuracy and
    the accuracy after the final PGD step are accumulated, printed, and logged
    to wandb under ``pgd/count``, ``pgd/natural_accuracy`` and
    ``pgd/adv_accuracy`` (with ``commit=False``, so they join the next commit).

    Args:
        args: run configuration, forwarded to ``get_clamp`` to build the
            ``(clamp, p_clamp)`` pair that ``pgd`` consumes.
        model: classifier under attack.
        vl_dl: loader yielding ``(data, target)`` batches.
        attack_params: dict with keys ``'epsilon'``, ``'num_steps'``,
            ``'step_size'`` and ``'random_start'``, forwarded to ``pgd``.

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).
    """
    clamp_lst = get_clamp(args)

    # Count the samples actually attacked instead of using len(vl_dl.dataset):
    # the two differ with drop_last=True or a subset/distributed sampler, which
    # would otherwise skew both accuracies.
    total_num = 0
    natural_num_correct = 0
    adv_num_correct = 0
    for data, target in tqdm(vl_dl):
        data = data.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        X = Variable(data, requires_grad=True)
        y = Variable(target)

        is_correct_natural, is_correct_adv, _ = pgd(
            model, X, y, clamp_lst,
            epsilon=attack_params['epsilon'],
            num_steps=attack_params['num_steps'],
            step_size=attack_params['step_size'],
            random_start=attack_params['random_start'])

        total_num += target.size(0)
        natural_num_correct += is_correct_natural.sum()
        # The last column holds correctness after the final PGD step.
        adv_num_correct += is_correct_adv[:, -1].sum()

    if total_num == 0:
        raise ValueError("attack_pgd: the validation loader yielded no samples")

    print(f"Accuracy after {adv_num_correct/total_num}")
    wandb.log({
        'pgd/count': total_num,
        'pgd/natural_accuracy': natural_num_correct / total_num,
        'pgd/adv_accuracy': adv_num_correct / total_num,
    }, commit=False)


def pgd(model, X, y, clamp_lst, epsilon=8 / 255, num_steps=20, step_size=0.01,
        random_start=True):
    """Run a signed-gradient PGD attack on one batch and record per-step accuracy.

    Per the original author's note, ``X`` is a single uncropped (patch) input.

    Args:
        model: classifier; predictions are ``argmax`` over dim 1 of its output.
        X: input batch tensor.
        y: integer class labels, one per sample (used with CrossEntropyLoss).
        clamp_lst: pair ``(clamp, p_clamp)``. Only ``p_clamp(X, perturbation)``
            is used here; it is applied to the perturbation after every step.
        epsilon: attack budget. The random start is uniform in
            ``[-epsilon, epsilon]``, but each step clips to ``±2 * epsilon``
            (see ``epsd`` below).
        num_steps: number of gradient-ascent steps.
        step_size: magnitude of each signed-gradient step.
        random_start: start from uniform noise instead of zeros.

    Returns:
        ``(is_correct_nat, is_correct_adv, perturbation)``:
        ``is_correct_nat`` is a float numpy array of shape ``(N,)`` with 1.0
        where the clean prediction is correct. ``is_correct_adv`` has shape
        ``(N, num_steps)``, and column ``i`` gives correctness after step ``i``.
        ``perturbation`` is the final perturbation tensor.
    """
    x_nat = X.detach()
    _, p_clamp = clamp_lst

    # Clean predictions only; no autograd graph is needed.
    with torch.no_grad():
        out = model(X)
    is_correct_nat = (out.max(1)[1] == y).float().cpu().numpy()

    if random_start:
        perturbation = torch.rand_like(X, requires_grad=True)
        perturbation.data = perturbation.data * 2 * epsilon - epsilon
    else:
        perturbation = torch.zeros_like(X, requires_grad=True)

    # NOTE(review): the step clip is twice the random-start range. Presumably
    # epsilon is given on a [0, 1] pixel scale while inputs live in [-1, 1]
    # (see the clip below) -- confirm against the data pipeline.
    epsd = epsilon * 2
    criterion = nn.CrossEntropyLoss()
    is_correct_adv = []

    for _ in range(num_steps):
        with torch.enable_grad():
            loss = criterion(model(x_nat + perturbation), y)
            # Differentiate w.r.t. the perturbation only. Unlike loss.backward(),
            # this does not accumulate into the model parameters' (or X's) .grad.
            grad, = torch.autograd.grad(loss, perturbation)

        perturbation.data = (
            perturbation.detach() + step_size * grad.sign()).clamp(-epsd, epsd)
        # Keep X + perturbation inside [-1, 1].
        perturbation.data = torch.min(
            torch.max(perturbation.detach(), -1 - x_nat), 1 - x_nat)
        perturbation.data = p_clamp(X, perturbation.detach())

        # NOTE(review): evaluation clamps X + perturbation to [0, 1], which
        # contradicts the [-1, 1] clip above; one of the two ranges is likely
        # wrong -- confirm the input range.
        with torch.no_grad():
            x_adv = torch.clamp(x_nat + perturbation.detach(), 0, 1.0)
            logits = model(x_adv)
        correct_adv = (logits.max(1)[1] == y).float().cpu().numpy()
        is_correct_adv.append(np.reshape(correct_adv, [-1, 1]))

    if not is_correct_adv:
        # num_steps == 0: keep the (N, num_steps) shape instead of letting
        # np.concatenate fail on an empty list.
        empty = np.empty((is_correct_nat.shape[0], 0), dtype=is_correct_nat.dtype)
        return is_correct_nat, empty, perturbation
    return is_correct_nat, np.concatenate(is_correct_adv, axis=1), perturbation
